{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1Z6Wtb_jisbA"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "QUyRGn9riopB"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "H1yCdGFW4j_F"
   },
   "source": [
    "# Premade Estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PS6_yKSoyLAl"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/estimator/premade\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />View on tensorflow.google.cn</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/estimator/premade.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/estimator/premade.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/zh-cn/tutorials/estimator/premade.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FgdA9XE5ZCS3"
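   },
   "source": [
    "Note: The TensorFlow community translated these documents. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest\n",
    "[official English documentation](https://www.tensorflow.org/?hl=en). If you have suggestions to improve this translation, please send a pull request to the\n",
    "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the\n",
    "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)."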
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R4YZ_ievcY7p"
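   },
   "source": [
    "\n",
    "This tutorial shows you how to solve the Iris classification problem in TensorFlow using Estimators. An Estimator is TensorFlow's high-level representation of a complete model, designed for easy scaling and asynchronous training. For more details, see [Estimators](https://tensorflow.google.cn/guide/estimator).\n",
    "\n",
    "Note that in TensorFlow 2.0, the [Keras API](https://tensorflow.google.cn/guide/keras) can accomplish many of the same tasks and is considered an easier API to learn. If you are just getting started, we recommend starting with Keras. For more information about the high-level APIs available in TensorFlow 2.0, see [Standardizing on Keras](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a).\n"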
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8IFct0yedsTy"
   },
   "source": [
    "## First things first\n",
    "\n",
    "To get started, first import TensorFlow and the other libraries you need.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jPo5bQwndr9P"
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "try:\n",
    "  # Colab only\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "    pass\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# NumPy is used by the sample input function further down.\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "c5w4m5gncnGh"
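   },
   "source": [
    "## The data set\n",
    "\n",
    "The sample program in this document builds and tests a model that classifies Iris flowers into three species based on the size of their [sepals](https://en.wikipedia.org/wiki/Sepal) and [petals](https://en.wikipedia.org/wiki/Petal).\n",
    "\n",
    "You will train a model using the Iris data set. The data set contains four features and one [label](https://developers.google.com/machine-learning/glossary/#label). The four features identify the following botanical characteristics of individual Iris flowers:\n",
    "\n",
    "* sepal length\n",
    "* sepal width\n",
    "* petal length\n",
    "* petal width\n",
    "\n",
    "Based on this information, you can define a few helpful constants for parsing the data:\n"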
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lSyrXp_He_UE"
   },
   "outputs": [],
   "source": [
    "CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']\n",
    "SPECIES = ['Setosa', 'Versicolor', 'Virginica']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "j6mTfIQzfC9w"
   },
   "source": [
    "Next, download and parse the Iris data set using Keras and Pandas. Note that you keep distinct datasets for training and testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PumyCN8VdGGc"
   },
   "outputs": [],
   "source": [
    "train_path = tf.keras.utils.get_file(\n",
    "    \"iris_training.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\")\n",
    "test_path = tf.keras.utils.get_file(\n",
    "    \"iris_test.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\")\n",
    "\n",
    "train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)\n",
    "test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wHFxNLszhQjz"
   },
   "source": [
    "Inspecting the data shows four floating-point feature columns and one integer label column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WOJt-ML4hAwI"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SepalLength</th>\n",
       "      <th>SepalWidth</th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "      <th>Species</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.4</td>\n",
       "      <td>2.8</td>\n",
       "      <td>5.6</td>\n",
       "      <td>2.2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5.0</td>\n",
       "      <td>2.3</td>\n",
       "      <td>3.3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.9</td>\n",
       "      <td>2.5</td>\n",
       "      <td>4.5</td>\n",
       "      <td>1.7</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.7</td>\n",
       "      <td>3.8</td>\n",
       "      <td>1.7</td>\n",
       "      <td>0.3</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   SepalLength  SepalWidth  PetalLength  PetalWidth  Species\n",
       "0          6.4         2.8          5.6         2.2        2\n",
       "1          5.0         2.3          3.3         1.0        1\n",
       "2          4.9         2.5          4.5         1.7        2\n",
       "3          4.9         3.1          1.5         0.1        0\n",
       "4          5.7         3.8          1.7         0.3        0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jQJEYfVvfznP"
   },
   "source": [
    "For each of the datasets, split out the labels, which the model will be trained to predict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sSaJNGeaZCTG"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SepalLength</th>\n",
       "      <th>SepalWidth</th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.4</td>\n",
       "      <td>2.8</td>\n",
       "      <td>5.6</td>\n",
       "      <td>2.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>5.0</td>\n",
       "      <td>2.3</td>\n",
       "      <td>3.3</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.9</td>\n",
       "      <td>2.5</td>\n",
       "      <td>4.5</td>\n",
       "      <td>1.7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.7</td>\n",
       "      <td>3.8</td>\n",
       "      <td>1.7</td>\n",
       "      <td>0.3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   SepalLength  SepalWidth  PetalLength  PetalWidth\n",
       "0          6.4         2.8          5.6         2.2\n",
       "1          5.0         2.3          3.3         1.0\n",
       "2          4.9         2.5          4.5         1.7\n",
       "3          4.9         3.1          1.5         0.1\n",
       "4          5.7         3.8          1.7         0.3"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_y = train.pop('Species')\n",
    "test_y = test.pop('Species')\n",
    "\n",
    "# The label column has now been removed from the features.\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jZx1L_1Vcmxv"
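   },
   "source": [
    "## Overview of programming with Estimators\n",
    "\n",
    "Now that you have the data set up, you can define a model using a TensorFlow Estimator. An Estimator is any class derived from `tf.estimator.Estimator`. TensorFlow provides a collection of premade Estimators in `tf.estimator` (for example, `LinearRegressor`) that implement common machine learning algorithms. Beyond those, you may write your own [custom Estimators](https://tensorflow.google.cn/guide/custom_estimators). We recommend using premade Estimators when you are just getting started.\n",
    "\n",
    "To write a TensorFlow program based on premade Estimators, you must perform the following tasks:\n",
    "\n",
    "* Create one or more input functions.\n",
    "* Define the model's feature columns.\n",
    "* Instantiate an Estimator, specifying the feature columns and various hyperparameters.\n",
    "* Call one or more methods on the Estimator object, passing the appropriate input function as the source of the data.\n",
    "\n",
    "Let's see how those tasks are implemented for Iris classification.\n"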
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2OcguDfBcmmg"
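   },
   "source": [
    "## Create input functions\n",
    "\n",
    "You must create input functions to supply data for training, evaluation, and prediction.\n",
    "\n",
    "An **input function** is a function that returns a `tf.data.Dataset` object which outputs the following two-element tuple:\n",
    "\n",
    "* [`features`](https://developers.google.com/machine-learning/glossary/#feature): a Python dictionary in which:\n",
    "    * Each key is the name of a feature.\n",
    "    * Each value is an array containing all of that feature's values.\n",
    "* `label`: an array containing the value of the [label](https://developers.google.com/machine-learning/glossary/#label) for every example.\n",
    "\n",
    "Just to demonstrate the format of the input function, here is a simple implementation:\n"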
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nzr5vRr5caGF"
   },
   "outputs": [],
   "source": [
    "def input_evaluation_set():\n",
    "    features = {'SepalLength': np.array([6.4, 5.0]),\n",
    "                'SepalWidth':  np.array([2.8, 2.3]),\n",
    "                'PetalLength': np.array([5.6, 3.3]),\n",
    "                'PetalWidth':  np.array([2.2, 1.0])}\n",
    "    labels = np.array([2, 1])\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NpXvGjfnjHgY"
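   },
   "source": [
    "\n",
    "Your input function may generate the `features` dictionary and `label` list any way you like. However, we recommend using TensorFlow's [Dataset API](https://tensorflow.google.cn/guide/datasets), which can parse many kinds of data.\n",
    "\n",
    "The Dataset API can handle a lot of common cases for you. For example, using the Dataset API, you can easily read in records from a large collection of files in parallel and join them into a single stream.\n",
    "\n",
    "To keep things simple in this example, you will load the data with [pandas](https://pandas.pydata.org/) and build an input pipeline from this in-memory data.\n"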
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T20u1anCi8NP"
   },
   "outputs": [],
   "source": [
    "def input_fn(features, labels, training=True, batch_size=256):\n",
    "    \"\"\"An input function for training or evaluating.\"\"\"\n",
    "    # Convert the inputs to a Dataset.\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))\n",
    "\n",
    "    # Shuffle and repeat the data if in training mode.\n",
    "    if training:\n",
    "        dataset = dataset.shuffle(1000).repeat()\n",
    "\n",
    "    return dataset.batch(batch_size)\n"
   ]
  },
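  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the effect of `shuffle`, `repeat`, and `batch` more concrete, the next cell is a rough pure-Python analogy. It is not how `tf.data` is implemented: the helper `toy_pipeline` and its parameters are hypothetical names introduced only for illustration, and it reshuffles the whole list on every pass, whereas `tf.data` shuffles through a fixed-size buffer. In training mode the stream never ends by itself, which is why a training call has to say how many steps to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "\n",
    "def toy_pipeline(examples, training=True, batch_size=4, seed=0):\n",
    "    # Hypothetical helper: yields batches, loosely mimicking\n",
    "    # shuffle -> repeat -> batch. Not part of TensorFlow.\n",
    "    rng = random.Random(seed)\n",
    "\n",
    "    def stream():\n",
    "        while True:\n",
    "            items = list(examples)\n",
    "            if not items:\n",
    "                return  # nothing to yield, avoid looping forever\n",
    "            if training:\n",
    "                rng.shuffle(items)  # new order on every pass\n",
    "            yield from items\n",
    "            if not training:\n",
    "                return  # evaluation: a single pass over the data\n",
    "\n",
    "    it = stream()\n",
    "    while True:\n",
    "        batch = list(itertools.islice(it, batch_size))\n",
    "        if not batch:\n",
    "            return\n",
    "        yield batch\n",
    "\n",
    "# Evaluation mode: one pass, and the last batch may be smaller.\n",
    "eval_batches = list(toy_pipeline(range(10), training=False))\n",
    "# Training mode is endless, so take only the first 5 batches.\n",
    "train_batches = list(itertools.islice(toy_pipeline(range(10)), 5))\n",
    "print(eval_batches)\n",
    "print(len(train_batches))"
   ]
  },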
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xIwcFT4MlZEi"
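   },
   "source": [
    "## Define the feature columns\n",
    "\n",
    "A [**feature column**](https://developers.google.com/machine-learning/glossary/#feature_columns) is an object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. The `tf.feature_column` module provides many options for representing data to the model.\n",
    "\n",
    "For Iris, the 4 raw features are numeric values, so you build a list of feature columns that tells the Estimator model to represent each of the four features as 32-bit floating-point values. The code to create the feature columns is:\n"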
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZTTriO8FlSML"
   },
   "outputs": [],
   "source": [
    "# Feature columns describe how to use the input.\n",
    "my_feature_columns = []\n",
    "for key in train.keys():\n",
    "    my_feature_columns.append(tf.feature_column.numeric_column(key=key))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jpKkhMoZljco"
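   },
   "source": [
    "Feature columns can be far more sophisticated than those shown here. You can read more about feature columns in [this guide](https://tensorflow.google.cn/guide/feature_columns).\n",
    "\n",
    "Now that you have described how the model should represent the raw features, you can build the Estimator.\n"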
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kuE59XHEl22K"
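   },
   "source": [
    "## Instantiate an Estimator\n",
    "\n",
    "The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several premade classifier Estimators, including:\n",
    "\n",
    "* `tf.estimator.DNNClassifier` for deep models that perform multi-class classification.\n",
    "* `tf.estimator.DNNLinearCombinedClassifier` for wide and deep models.\n",
    "* `tf.estimator.LinearClassifier` for classifiers based on linear models.\n",
    "\n",
    "For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice. Here is how you instantiate this Estimator:\n"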
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qnf4o2V5lcPn"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpst2po1mq\n",
      "INFO:tensorflow:Using config: {'_model_dir': 'C:\\\\Users\\\\Awebone\\\\AppData\\\\Local\\\\Temp\\\\tmpst2po1mq', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000025622114D68>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
     ]
    }
   ],
   "source": [
    "# Build a deep neural network with two hidden layers of 30 and 10 nodes.\n",
    "classifier = tf.estimator.DNNClassifier(\n",
    "    feature_columns=my_feature_columns,\n",
    "    # Two hidden layers of 30 and 10 nodes respectively.\n",
    "    hidden_units=[30, 10],\n",
    "    # The model must choose between 3 classes.\n",
    "    n_classes=3)"
   ]
  },
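  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a rough sanity check on the size of this model, the next cell counts the weights implied by four input features, `hidden_units=[30, 10]`, and `n_classes=3`. It assumes plain fully connected layers with one bias per unit. The helper `count_dense_parameters` is a hypothetical name that is not part of TensorFlow, and the real `DNNClassifier` may create further variables (for example, optimizer state) that this count ignores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def count_dense_parameters(n_inputs, hidden_units, n_outputs):\n",
    "    # Hypothetical helper, assuming fully connected layers with biases.\n",
    "    # Layer sizes from input to output, e.g. [4, 30, 10, 3].\n",
    "    sizes = [n_inputs] + list(hidden_units) + [n_outputs]\n",
    "    total = 0\n",
    "    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):\n",
    "        total += fan_in * fan_out + fan_out  # weights + biases\n",
    "    return total\n",
    "\n",
    "n_params = count_dense_parameters(4, [30, 10], 3)\n",
    "print(n_params)"
   ]
  },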
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tzzt5nUpmEe3"
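   },
   "source": [
    "## Train, evaluate, and predict\n",
    "\n",
    "Now that you have an Estimator object, you can call methods to do the following:\n",
    "\n",
    "* Train the model.\n",
    "* Evaluate the trained model.\n",
    "* Use the trained model to make predictions."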
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rnihuLdWmE75"
   },
   "source": [
    "### Train the model\n",
    "\n",
    "Train the model by calling the Estimator's `train` method as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4jW08YtPl1iS"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n",
      "WARNING:tensorflow:From f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\training\\training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "WARNING:tensorflow:From f:\\condaenv\\tf2_py36\\lib\\site-packages\\tensorflow_core\\python\\keras\\optimizer_v2\\adagrad.py:108: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpst2po1mq\\model.ckpt.\n",
      "INFO:tensorflow:loss = 1.2809983, step = 0\n",
      "INFO:tensorflow:global_step/sec: 194.954\n",
      "INFO:tensorflow:loss = 0.99874, step = 100 (0.515 sec)\n",
      "INFO:tensorflow:global_step/sec: 238.914\n",
      "INFO:tensorflow:loss = 0.94586504, step = 200 (0.417 sec)\n",
      "INFO:tensorflow:global_step/sec: 213.725\n",
      "INFO:tensorflow:loss = 0.9394318, step = 300 (0.468 sec)\n",
      "INFO:tensorflow:global_step/sec: 195.754\n",
      "INFO:tensorflow:loss = 0.92686695, step = 400 (0.515 sec)\n",
      "INFO:tensorflow:global_step/sec: 233.154\n",
      "INFO:tensorflow:loss = 0.91447425, step = 500 (0.427 sec)\n",
      "INFO:tensorflow:global_step/sec: 211.869\n",
      "INFO:tensorflow:loss = 0.8933535, step = 600 (0.470 sec)\n",
      "INFO:tensorflow:global_step/sec: 223.204\n",
      "INFO:tensorflow:loss = 0.8631328, step = 700 (0.449 sec)\n",
      "INFO:tensorflow:global_step/sec: 228.993\n",
      "INFO:tensorflow:loss = 0.8349169, step = 800 (0.436 sec)\n",
      "INFO:tensorflow:global_step/sec: 229.278\n",
      "INFO:tensorflow:loss = 0.8111194, step = 900 (0.452 sec)\n",
      "INFO:tensorflow:global_step/sec: 231.125\n",
      "INFO:tensorflow:loss = 0.78953296, step = 1000 (0.417 sec)\n",
      "INFO:tensorflow:global_step/sec: 232.271\n",
      "INFO:tensorflow:loss = 0.77068985, step = 1100 (0.431 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.71\n",
      "INFO:tensorflow:loss = 0.7517746, step = 1200 (0.415 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.5\n",
      "INFO:tensorflow:loss = 0.73730665, step = 1300 (0.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 215.565\n",
      "INFO:tensorflow:loss = 0.7229214, step = 1400 (0.466 sec)\n",
      "INFO:tensorflow:global_step/sec: 227.576\n",
      "INFO:tensorflow:loss = 0.7071636, step = 1500 (0.437 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.292\n",
      "INFO:tensorflow:loss = 0.6940389, step = 1600 (0.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 243.703\n",
      "INFO:tensorflow:loss = 0.68749726, step = 1700 (0.410 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.619\n",
      "INFO:tensorflow:loss = 0.67062706, step = 1800 (0.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 247.545\n",
      "INFO:tensorflow:loss = 0.66153854, step = 1900 (0.420 sec)\n",
      "INFO:tensorflow:global_step/sec: 241.551\n",
      "INFO:tensorflow:loss = 0.6498587, step = 2000 (0.398 sec)\n",
      "INFO:tensorflow:global_step/sec: 239.355\n",
      "INFO:tensorflow:loss = 0.63856035, step = 2100 (0.418 sec)\n",
      "INFO:tensorflow:global_step/sec: 238.798\n",
      "INFO:tensorflow:loss = 0.63188624, step = 2200 (0.419 sec)\n",
      "INFO:tensorflow:global_step/sec: 241.708\n",
      "INFO:tensorflow:loss = 0.6198819, step = 2300 (0.414 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.268\n",
      "INFO:tensorflow:loss = 0.611103, step = 2400 (0.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 220.624\n",
      "INFO:tensorflow:loss = 0.5996893, step = 2500 (0.453 sec)\n",
      "INFO:tensorflow:global_step/sec: 235.848\n",
      "INFO:tensorflow:loss = 0.5879957, step = 2600 (0.426 sec)\n",
      "INFO:tensorflow:global_step/sec: 238.588\n",
      "INFO:tensorflow:loss = 0.5826921, step = 2700 (0.433 sec)\n",
      "INFO:tensorflow:global_step/sec: 225.018\n",
      "INFO:tensorflow:loss = 0.5788972, step = 2800 (0.429 sec)\n",
      "INFO:tensorflow:global_step/sec: 239.223\n",
      "INFO:tensorflow:loss = 0.57051224, step = 2900 (0.418 sec)\n",
      "INFO:tensorflow:global_step/sec: 247.012\n",
      "INFO:tensorflow:loss = 0.5632386, step = 3000 (0.405 sec)\n",
      "INFO:tensorflow:global_step/sec: 249.278\n",
      "INFO:tensorflow:loss = 0.55913097, step = 3100 (0.417 sec)\n",
      "INFO:tensorflow:global_step/sec: 241.274\n",
      "INFO:tensorflow:loss = 0.5421753, step = 3200 (0.399 sec)\n",
      "INFO:tensorflow:global_step/sec: 239.366\n",
      "INFO:tensorflow:loss = 0.53468925, step = 3300 (0.433 sec)\n",
      "INFO:tensorflow:global_step/sec: 239.884\n",
      "INFO:tensorflow:loss = 0.5349634, step = 3400 (0.401 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.631\n",
      "INFO:tensorflow:loss = 0.5298414, step = 3500 (0.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.28\n",
      "INFO:tensorflow:loss = 0.52088463, step = 3600 (0.416 sec)\n",
      "INFO:tensorflow:global_step/sec: 240.076\n",
      "INFO:tensorflow:loss = 0.5163662, step = 3700 (0.417 sec)\n",
      "INFO:tensorflow:global_step/sec: 231.71\n",
      "INFO:tensorflow:loss = 0.500163, step = 3800 (0.432 sec)\n",
      "INFO:tensorflow:global_step/sec: 239.87\n",
      "INFO:tensorflow:loss = 0.50416887, step = 3900 (0.417 sec)\n",
      "INFO:tensorflow:global_step/sec: 233.058\n",
      "INFO:tensorflow:loss = 0.508107, step = 4000 (0.429 sec)\n",
      "INFO:tensorflow:global_step/sec: 236.043\n",
      "INFO:tensorflow:loss = 0.48513377, step = 4100 (0.424 sec)\n",
      "INFO:tensorflow:global_step/sec: 241.88\n",
      "INFO:tensorflow:loss = 0.4869864, step = 4200 (0.413 sec)\n",
      "INFO:tensorflow:global_step/sec: 231.288\n",
      "INFO:tensorflow:loss = 0.48726383, step = 4300 (0.432 sec)\n",
      "INFO:tensorflow:global_step/sec: 215.615\n",
      "INFO:tensorflow:loss = 0.48555857, step = 4400 (0.464 sec)\n",
      "INFO:tensorflow:global_step/sec: 209.3\n",
      "INFO:tensorflow:loss = 0.4740008, step = 4500 (0.481 sec)\n",
      "INFO:tensorflow:global_step/sec: 226.79\n",
      "INFO:tensorflow:loss = 0.47090137, step = 4600 (0.438 sec)\n",
      "INFO:tensorflow:global_step/sec: 232.56\n",
      "INFO:tensorflow:loss = 0.46782875, step = 4700 (0.446 sec)\n",
      "INFO:tensorflow:global_step/sec: 229.564\n",
      "INFO:tensorflow:loss = 0.4667447, step = 4800 (0.420 sec)\n",
      "INFO:tensorflow:global_step/sec: 216.127\n",
      "INFO:tensorflow:loss = 0.45524535, step = 4900 (0.463 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 5000 into C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpst2po1mq\\model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.45230627.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x25622114860>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the model.\n",
    "classifier.train(\n",
    "    input_fn=lambda: input_fn(train, train_y, training=True),\n",
    "    steps=5000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ybiTFDmlmes8"
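   },
   "source": [
    "Note that the `input_fn` call is wrapped in a [`lambda`](https://docs.python.org/3/tutorial/controlflow.html) to capture the arguments while still providing an input function that takes no arguments, as the Estimator expects. The `steps` argument tells the method to stop training after that number of training steps.\n"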
   ]
  },
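  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This wrapping pattern does not depend on TensorFlow. The sketch below uses a hypothetical stand-in, `make_batches`, in place of `input_fn`, with made-up sample values, to show that both a `lambda` and `functools.partial` turn a function with arguments into a zero-argument callable. Either form should be usable wherever a no-argument input function is expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "def make_batches(features, labels, batch_size=2):\n",
    "    # Hypothetical stand-in for input_fn: pair up features and labels,\n",
    "    # then split the pairs into batches (no TensorFlow involved).\n",
    "    pairs = list(zip(features, labels))\n",
    "    return [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]\n",
    "\n",
    "sample_features = [5.1, 5.9, 6.9]\n",
    "sample_labels = [0, 1, 2]\n",
    "\n",
    "# Both callables take no arguments. The lambda looks up the two names\n",
    "# when it is called, while partial binds the objects immediately.\n",
    "via_lambda = lambda: make_batches(sample_features, sample_labels, batch_size=2)\n",
    "via_partial = functools.partial(make_batches, sample_features, sample_labels, batch_size=2)\n",
    "\n",
    "print(via_lambda() == via_partial())"
   ]
  },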
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HNvJLH8hmsdf"
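   },
   "source": [
    "### Evaluate the trained model\n",
    "\n",
    "Now that the model has been trained, you can get some statistics on its performance. The following code block evaluates the accuracy of the trained model on the test data:\n"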
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "A169XuO4mKxF"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Starting evaluation at 2020-02-04T21:51:31Z\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpst2po1mq\\model.ckpt-5000\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Finished evaluation at 2020-02-04-21:51:31\n",
      "INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.9, average_loss = 0.518781, global_step = 5000, loss = 0.518781\n",
      "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpst2po1mq\\model.ckpt-5000\n",
      "\n",
      "Test set accuracy: 0.900\n",
      "\n"
     ]
    }
   ],
   "source": [
    "eval_result = classifier.evaluate(\n",
    "    input_fn=lambda: input_fn(test, test_y, training=False))\n",
    "\n",
    "print('\\nTest set accuracy: {accuracy:0.3f}\\n'.format(**eval_result))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VnPMP5EHph17"
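   },
   "source": [
    "Unlike the call to the `train` method, no `steps` argument is passed to `evaluate`. The `input_fn` for evaluation yields only a single [epoch](https://developers.google.com/machine-learning/glossary/#epoch) of data.\n",
    "\n",
    "The `eval_result` dictionary also contains the `average_loss` (mean loss per sample), the `loss` (mean loss per mini-batch), and the value of the Estimator's `global_step` (the number of training iterations it underwent)."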
   ]
  },
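  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To see how the `print` call in the evaluation cell consumes this dictionary, the next cell rebuilds it by hand under the illustrative name `logged_result`. The values are copied from the log output of the evaluation run recorded in this notebook; another run will generally produce different numbers. `str.format(**logged_result)` unpacks the dictionary into keyword arguments and ignores the keys that the template does not mention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Values copied from the recorded evaluation log; they vary from run to run.\n",
    "logged_result = {\n",
    "    'accuracy': 0.9,\n",
    "    'average_loss': 0.518781,\n",
    "    'loss': 0.518781,\n",
    "    'global_step': 5000,\n",
    "}\n",
    "\n",
    "# Only the accuracy key is used by this template.\n",
    "message = 'Test set accuracy: {accuracy:0.3f}'.format(**logged_result)\n",
    "print(message)"
   ]
  },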
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ur624ibpp52X"
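   },
   "source": [
    "### Make predictions (inference) with the trained model\n",
    "\n",
    "You now have a trained model that produces good evaluation results. You can use it to predict the species of an Iris flower from unlabeled measurements. As with training and evaluation, you make predictions using a single function call:"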
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wltc0jpgng38"
   },
   "outputs": [],
   "source": [
    "# Generate predictions from the model.\n",
    "expected = ['Setosa', 'Versicolor', 'Virginica']\n",
    "predict_x = {\n",
    "    'SepalLength': [5.1, 5.9, 6.9],\n",
    "    'SepalWidth': [3.3, 3.0, 3.1],\n",
    "    'PetalLength': [1.7, 4.2, 5.4],\n",
    "    'PetalWidth': [0.5, 1.5, 2.1],\n",
    "}\n",
    "\n",
    "# Named differently from the training input_fn so that re-running earlier cells still works.\n",
    "def predict_input_fn(features, batch_size=256):\n",
    "    \"\"\"An input function for prediction.\"\"\"\n",
    "    # Convert the inputs to a Dataset without labels.\n",
    "    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)\n",
    "\n",
    "predictions = classifier.predict(\n",
    "    input_fn=lambda: predict_input_fn(predict_x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JsETKQo0rHvi"
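   },
   "source": [
    "The `predict` method returns a Python iterable, yielding a dictionary of prediction results for each example. The following code prints a few predictions and their probabilities:"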
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Efm4mLzkrCxp"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from C:\\Users\\Awebone\\AppData\\Local\\Temp\\tmpst2po1mq\\model.ckpt-5000\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "Prediction is \"Setosa\" (76.6%), expected \"Setosa\"\n",
      "Prediction is \"Versicolor\" (47.8%), expected \"Versicolor\"\n",
      "Prediction is \"Virginica\" (59.8%), expected \"Virginica\"\n"
     ]
    }
   ],
   "source": [
    "for pred_dict, expec in zip(predictions, expected):\n",
    "    class_id = pred_dict['class_ids'][0]\n",
    "    probability = pred_dict['probabilities'][class_id]\n",
    "\n",
    "    print('Prediction is \"{}\" ({:.1f}%), expected \"{}\"'.format(\n",
    "        SPECIES[class_id], 100 * probability, expec))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "premade.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
